{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp common._modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NN Modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "ACTIVATIONS = ['ReLU','Softplus','Tanh','SELU','LeakyReLU','PReLU','Sigmoid']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. MLP\n",
    "\n",
    "Multi-Layer Perceptron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Multi-Layer Perceptron Class\n",
    "\n",
    "    Stack of fully connected layers: an input layer (`in_features` -> `hidden_size`),\n",
    "    `num_layers - 2` hidden layers (`hidden_size` -> `hidden_size`) and an output\n",
    "    layer (`hidden_size` -> `out_features`). Every layer except the output one is\n",
    "    followed by the activation and dropout.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `in_features`: int, dimension of input.<br>\n",
    "    `out_features`: int, dimension of output.<br>\n",
    "    `activation`: str, activation function to use, one of `ACTIVATIONS`.<br>\n",
    "    `hidden_size`: int, dimension of hidden layers.<br>\n",
    "    `num_layers`: int, total number of linear layers, input and output layers included\n",
    "        (values below 2 still build the input and output layers).<br>\n",
    "    `dropout`: float, dropout rate.<br>\n",
    "    \"\"\"\n",
    "    def __init__(self, in_features, out_features, activation, hidden_size, num_layers, dropout):\n",
    "        super().__init__()\n",
    "        assert activation in ACTIVATIONS, f'{activation} is not in {ACTIVATIONS}'\n",
    "\n",
    "        # One activation instance is reused after every layer, so for `PReLU`\n",
    "        # all layers share the same learnable parameters.\n",
    "        self.activation = getattr(nn, activation)()\n",
    "\n",
    "        # Input layer\n",
    "        layers = [nn.Linear(in_features=in_features, out_features=hidden_size),\n",
    "                  self.activation,\n",
    "                  nn.Dropout(dropout)]\n",
    "        # Hidden layers: together with the input and output layers this makes\n",
    "        # `num_layers` linear layers in total.\n",
    "        for _ in range(num_layers - 2):\n",
    "            layers += [nn.Linear(in_features=hidden_size, out_features=hidden_size),\n",
    "                       self.activation,\n",
    "                       nn.Dropout(dropout)]\n",
    "        # Output layer: no activation or dropout after it\n",
    "        layers += [nn.Linear(in_features=hidden_size, out_features=out_features)]\n",
    "\n",
    "        # Chain all layers in order\n",
    "        self.layers = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the layer stack; the last dimension of `x` must be `in_features`.\"\"\"\n",
    "        return self.layers(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Temporal Convolutions\n",
    "\n",
    "For a long time in deep learning, sequence modelling was synonymous with recurrent networks, yet several papers have shown that simple convolutional architectures can outperform canonical recurrent networks like LSTMs while exhibiting a longer effective memory.\n",
    "\n",
    "**References**<br>\n",
    "- [van den Oord, A., Dieleman, S., Zen, H., Simonyan, K., Vinyals, O., Graves, A., Kalchbrenner, N., Senior, A. W., & Kavukcuoglu, K. (2016). WaveNet: A generative model for raw audio. Computing Research Repository, abs/1609.03499. URL: http://arxiv.org/abs/1609.03499. arXiv:1609.03499.](https://arxiv.org/abs/1609.03499)<br>\n",
    "- [Shaojie Bai, Zico Kolter, Vladlen Koltun. (2018). An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling. Computing Research Repository, abs/1803.01271. URL: https://arxiv.org/abs/1803.01271.](https://arxiv.org/abs/1803.01271)<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Chomp1d(nn.Module):\n",
    "    \"\"\" Chomp1d\n",
    "\n",
    "    Receives `x` input of dim [N,C,T], and trims it so that only\n",
    "    'time available' information is used. \n",
    "    Used by one dimensional causal convolutions `CausalConv1d`.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `horizon`: int, length of outsample values to skip.\n",
    "    \"\"\"\n",
    "    def __init__(self, horizon):\n",
    "        super(Chomp1d, self).__init__()\n",
    "        self.horizon = horizon\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x[:, :, :-self.horizon].contiguous()\n",
    "\n",
    "\n",
    "class CausalConv1d(nn.Module):\n",
    "    r\"\"\" Causal Convolution 1d\n",
    "\n",
    "    Receives `x` input of dim [N,C_in,T], and computes a causal convolution\n",
    "    in the time dimension: `x` is zero padded by `padding` steps on both sides\n",
    "    and the last `padding` output steps are trimmed with `Chomp1d`, so that with\n",
    "    `padding=(kernel_size-1)*dilation` the output at step $t$ only depends on\n",
    "    inputs at steps $\\leq t$.\n",
    "    Consider a batch of one element, the dilated convolution operation on the\n",
    "    $t$ time step is defined:\n",
    "\n",
    "    $\\mathrm{Conv1D}(\\mathbf{x},\\mathbf{w})(t) = (\\mathbf{x} *_{d} \\mathbf{w})(t) = \\sum^{K-1}_{k=0} w_{k} \\mathbf{x}_{t-dk}$\n",
    "\n",
    "    where $d$ is the dilation factor, $K$ is the kernel size, $t-dk$ is the index of\n",
    "    the considered past observation. The dilation effectively applies a filter with skip\n",
    "    connections. If $d=1$ one recovers a normal convolution.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `in_channels`: int, dimension of `x` input's initial channels.<br>\n",
    "    `out_channels`: int, dimension of `x` outputs's channels.<br>\n",
    "    `kernel_size`: int, convolution's kernel size.<br>\n",
    "    `padding`: int, number of zero padding steps added to each side; as many output steps are then trimmed from the right.<br>\n",
    "    `dilation`: int, dilation skip connections.<br>\n",
    "    `activation`: str, identifying activations from PyTorch activations.\n",
    "        select from 'ReLU','Softplus','Tanh','SELU', 'LeakyReLU','PReLU','Sigmoid'.<br>\n",
    "    `stride`: int, convolution's stride (default 1).<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x`: tensor, torch tensor of dim [N,C_out,T] activation(conv1d(inputs, kernel) + bias)\n",
    "        when `stride=1` and `padding=(kernel_size-1)*dilation`.<br>\n",
    "    \"\"\"\n",
    "    def __init__(self, in_channels, out_channels, kernel_size,\n",
    "                 padding, dilation, activation, stride:int=1):\n",
    "        super(CausalConv1d, self).__init__()\n",
    "        assert activation in ACTIVATIONS, f'{activation} is not in {ACTIVATIONS}'\n",
    "\n",
    "        self.conv       = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,\n",
    "                                    kernel_size=kernel_size, stride=stride, padding=padding,\n",
    "                                    dilation=dilation)\n",
    "\n",
    "        # Drop the trailing `padding` outputs so each output is aligned with its causal position.\n",
    "        self.chomp      = Chomp1d(padding)\n",
    "        self.activation = getattr(nn, activation)()\n",
    "        self.causalconv = nn.Sequential(self.conv, self.chomp, self.activation)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.causalconv(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(CausalConv1d, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TemporalConvolutionEncoder(nn.Module):\n",
    "    \"\"\" Temporal Convolution Encoder\n",
    "\n",
    "    Receives `x` input of dim [N,T,C_in], permutes it to [N,C_in,T] and\n",
    "    applies a stack of dilated causal convolutions, one `CausalConv1d` per\n",
    "    entry of `dilations`. With exponentially increasing dilations the receptive\n",
    "    field grows exponentially with depth, which gives the encoder a long-term memory.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `in_channels`: int, dimension of `x` input's initial channels.<br>\n",
    "    `out_channels`: int, dimension of `x` outputs's channels, also used by every intermediate layer.<br>\n",
    "    `kernel_size`: int, size of the convolving kernel.<br>\n",
    "    `dilations`: int list, dilation of each convolution layer, in order.<br>\n",
    "    `activation`: str, identifying activations from PyTorch activations.\n",
    "        select from 'ReLU','Softplus','Tanh','SELU', 'LeakyReLU','PReLU','Sigmoid'.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `x`: tensor, torch tensor of dim [N,T,C_out].<br>\n",
    "    \"\"\"\n",
    "    def __init__(self, in_channels, out_channels,\n",
    "                 kernel_size, dilations,\n",
    "                 activation:str='ReLU'):\n",
    "        super(TemporalConvolutionEncoder, self).__init__()\n",
    "        dilation_list = list(dilations)\n",
    "        # The first layer reads `in_channels`; each later layer reads the\n",
    "        # `out_channels` produced by the layer before it.\n",
    "        layer_inputs = [in_channels] + [out_channels] * (len(dilation_list) - 1)\n",
    "        # Padding of (kernel_size-1)*dilation, chomped on the right, keeps each convolution causal.\n",
    "        self.tcn = nn.Sequential(*[\n",
    "            CausalConv1d(in_channels=c_in, out_channels=out_channels,\n",
    "                         kernel_size=kernel_size, padding=(kernel_size - 1) * d,\n",
    "                         activation=activation, dilation=d)\n",
    "            for c_in, d in zip(layer_inputs, dilation_list)\n",
    "        ])\n",
    "\n",
    "    def forward(self, x):\n",
    "        # [N,T,C_in] -> [N,C_in,T] -> stacked causal convs -> [N,C_out,T] -> [N,T,C_out]\n",
    "        channels_first = x.permute(0, 2, 1).contiguous()\n",
    "        encoded = self.tcn(channels_first)\n",
    "        return encoded.permute(0, 2, 1).contiguous()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TemporalConvolutionEncoder, title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Transformers"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**References**<br>\n",
    "- [Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, Wancai Zhang. \"Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting\"](https://arxiv.org/abs/2012.07436)<br>\n",
    "- [Haixu Wu, Jiehui Xu, Jianmin Wang, Mingsheng Long. \"Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting\"](https://arxiv.org/abs/2106.13008)<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TransEncoderLayer(nn.Module):\n",
    "    def __init__(self, attention, hidden_size, conv_hidden_size=None, dropout=0.1, activation=\"relu\"):\n",
    "        super(TransEncoderLayer, self).__init__()\n",
    "        conv_hidden_size = conv_hidden_size or 4 * hidden_size\n",
    "        self.attention = attention\n",
    "        self.conv1 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_hidden_size, kernel_size=1)\n",
    "        self.conv2 = nn.Conv1d(in_channels=conv_hidden_size, out_channels=hidden_size, kernel_size=1)\n",
    "        self.norm1 = nn.LayerNorm(hidden_size)\n",
    "        self.norm2 = nn.LayerNorm(hidden_size)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.activation = F.relu if activation == \"relu\" else F.gelu\n",
    "\n",
    "    def forward(self, x, attn_mask=None):\n",
    "        new_x, attn = self.attention(\n",
    "            x, x, x,\n",
    "            attn_mask=attn_mask\n",
    "        )\n",
    "        \n",
    "        x = x + self.dropout(new_x)\n",
    "\n",
    "        y = x = self.norm1(x)\n",
    "        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))\n",
    "        y = self.dropout(self.conv2(y).transpose(-1, 1))\n",
    "\n",
    "        return self.norm2(x + y), attn\n",
    "\n",
    "\n",
    "class TransEncoder(nn.Module):\n",
    "    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):\n",
    "        super(TransEncoder, self).__init__()\n",
    "        self.attn_layers = nn.ModuleList(attn_layers)\n",
    "        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None\n",
    "        self.norm = norm_layer\n",
    "\n",
    "    def forward(self, x, attn_mask=None):\n",
    "        # x [B, L, D]\n",
    "        attns = []\n",
    "        if self.conv_layers is not None:\n",
    "            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):\n",
    "                x, attn = attn_layer(x, attn_mask=attn_mask)\n",
    "                x = conv_layer(x)\n",
    "                attns.append(attn)\n",
    "            x, attn = self.attn_layers[-1](x)\n",
    "            attns.append(attn)\n",
    "        else:\n",
    "            for attn_layer in self.attn_layers:\n",
    "                x, attn = attn_layer(x, attn_mask=attn_mask)\n",
    "                attns.append(attn)\n",
    "\n",
    "        if self.norm is not None:\n",
    "            x = self.norm(x)\n",
    "\n",
    "        return x, attns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TransDecoderLayer(nn.Module):\n",
    "    def __init__(self, self_attention, cross_attention, hidden_size, conv_hidden_size=None,\n",
    "                 dropout=0.1, activation=\"relu\"):\n",
    "        super(TransDecoderLayer, self).__init__()\n",
    "        conv_hidden_size = conv_hidden_size or 4 * hidden_size\n",
    "        self.self_attention = self_attention\n",
    "        self.cross_attention = cross_attention\n",
    "        self.conv1 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_hidden_size, kernel_size=1)\n",
    "        self.conv2 = nn.Conv1d(in_channels=conv_hidden_size, out_channels=hidden_size, kernel_size=1)\n",
    "        self.norm1 = nn.LayerNorm(hidden_size)\n",
    "        self.norm2 = nn.LayerNorm(hidden_size)\n",
    "        self.norm3 = nn.LayerNorm(hidden_size)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.activation = F.relu if activation == \"relu\" else F.gelu\n",
    "\n",
    "    def forward(self, x, cross, x_mask=None, cross_mask=None):\n",
    "        x = x + self.dropout(self.self_attention(\n",
    "            x, x, x,\n",
    "            attn_mask=x_mask\n",
    "        )[0])\n",
    "        x = self.norm1(x)\n",
    "\n",
    "        x = x + self.dropout(self.cross_attention(\n",
    "            x, cross, cross,\n",
    "            attn_mask=cross_mask\n",
    "        )[0])\n",
    "\n",
    "        y = x = self.norm2(x)\n",
    "        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))\n",
    "        y = self.dropout(self.conv2(y).transpose(-1, 1))\n",
    "\n",
    "        return self.norm3(x + y)\n",
    "\n",
    "\n",
    "class TransDecoder(nn.Module):\n",
    "    def __init__(self, layers, norm_layer=None, projection=None):\n",
    "        super(TransDecoder, self).__init__()\n",
    "        self.layers = nn.ModuleList(layers)\n",
    "        self.norm = norm_layer\n",
    "        self.projection = projection\n",
    "\n",
    "    def forward(self, x, cross, x_mask=None, cross_mask=None):\n",
    "        for layer in self.layers:\n",
    "            x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)\n",
    "\n",
    "        if self.norm is not None:\n",
    "            x = self.norm(x)\n",
    "\n",
    "        if self.projection is not None:\n",
    "            x = self.projection(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class AttentionLayer(nn.Module):\n",
    "    def __init__(self, attention, hidden_size, n_head, d_keys=None,\n",
    "                 d_values=None):\n",
    "        super(AttentionLayer, self).__init__()\n",
    "\n",
    "        d_keys = d_keys or (hidden_size // n_head)\n",
    "        d_values = d_values or (hidden_size // n_head)\n",
    "\n",
    "        self.inner_attention = attention\n",
    "        self.query_projection = nn.Linear(hidden_size, d_keys * n_head)\n",
    "        self.key_projection = nn.Linear(hidden_size, d_keys * n_head)\n",
    "        self.value_projection = nn.Linear(hidden_size, d_values * n_head)\n",
    "        self.out_projection = nn.Linear(d_values * n_head, hidden_size)\n",
    "        self.n_head = n_head\n",
    "\n",
    "    def forward(self, queries, keys, values, attn_mask):\n",
    "        B, L, _ = queries.shape\n",
    "        _, S, _ = keys.shape\n",
    "        H = self.n_head\n",
    "\n",
    "        queries = self.query_projection(queries).view(B, L, H, -1)\n",
    "        keys = self.key_projection(keys).view(B, S, H, -1)\n",
    "        values = self.value_projection(values).view(B, S, H, -1)\n",
    "\n",
    "        out, attn = self.inner_attention(\n",
    "            queries,\n",
    "            keys,\n",
    "            values,\n",
    "            attn_mask\n",
    "        )\n",
    "        out = out.view(B, L, -1)\n",
    "\n",
    "        return self.out_projection(out), attn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class PositionalEmbedding(nn.Module):\n",
    "    def __init__(self, hidden_size, max_len=5000):\n",
    "        super(PositionalEmbedding, self).__init__()\n",
    "        # Compute the positional encodings once in log space.\n",
    "        pe = torch.zeros(max_len, hidden_size).float()\n",
    "        pe.require_grad = False\n",
    "\n",
    "        position = torch.arange(0, max_len).float().unsqueeze(1)\n",
    "        div_term = (torch.arange(0, hidden_size, 2).float() * -(math.log(10000.0) / hidden_size)).exp()\n",
    "\n",
    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
    "\n",
    "        pe = pe.unsqueeze(0)\n",
    "        self.register_buffer('pe', pe)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.pe[:, :x.size(1)]\n",
    "\n",
    "class TokenEmbedding(nn.Module):\n",
    "    def __init__(self, c_in, hidden_size):\n",
    "        super(TokenEmbedding, self).__init__()\n",
    "        padding = 1 if torch.__version__ >= '1.5.0' else 2\n",
    "        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=hidden_size,\n",
    "                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv1d):\n",
    "                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)\n",
    "        return x\n",
    "\n",
    "class TimeFeatureEmbedding(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        super(TimeFeatureEmbedding, self).__init__()\n",
    "        self.embed = nn.Linear(input_size, hidden_size, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.embed(x)\n",
    "\n",
    "class DataEmbedding(nn.Module):\n",
    "    \"\"\"Input embedding: value embedding + optional positional embedding +\n",
    "    optional temporal (exogenous) embedding, followed by dropout.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `c_in`: int, number of channels of `x`.<br>\n",
    "    `exog_input_size`: int, number of temporal features in `x_mark`; `0` disables the temporal embedding.<br>\n",
    "    `hidden_size`: int, embedding dimension.<br>\n",
    "    `pos_embedding`: bool, whether to add the sinusoidal positional embedding.<br>\n",
    "    `dropout`: float, dropout rate.<br>\n",
    "    \"\"\"\n",
    "    def __init__(self, c_in, exog_input_size, hidden_size, pos_embedding=True, dropout=0.1):\n",
    "        super(DataEmbedding, self).__init__()\n",
    "\n",
    "        self.value_embedding = TokenEmbedding(c_in=c_in, hidden_size=hidden_size)\n",
    "\n",
    "        if pos_embedding:\n",
    "            self.position_embedding = PositionalEmbedding(hidden_size=hidden_size)\n",
    "        else:\n",
    "            self.position_embedding = None\n",
    "\n",
    "        if exog_input_size > 0:\n",
    "            self.temporal_embedding = TimeFeatureEmbedding(input_size=exog_input_size,\n",
    "                                                           hidden_size=hidden_size)\n",
    "        else:\n",
    "            self.temporal_embedding = None\n",
    "\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "\n",
    "    def forward(self, x, x_mark=None):\n",
    "        \"\"\"Embed `x`; `x_mark` is required when the temporal embedding is enabled.\"\"\"\n",
    "        # Value embedding (convolution over the time axis)\n",
    "        x = self.value_embedding(x)\n",
    "\n",
    "        # Add positional (relative within window) embedding with sines and cosines\n",
    "        if self.position_embedding is not None:\n",
    "            x = x + self.position_embedding(x)\n",
    "\n",
    "        # Add temporal (absolute in time series) embedding with linear layer\n",
    "        if self.temporal_embedding is not None:\n",
    "            if x_mark is None:\n",
    "                raise TypeError('x_mark is required when exog_input_size > 0')\n",
    "            x = x + self.temporal_embedding(x_mark)\n",
    "\n",
    "        return self.dropout(x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
